{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- 在上一节中，我们简单地实现了一个十分十分基础甚至有些简陋的GPT，同时起生成效果看起来也有很大的提升空间\n",
    "- 这一节中，我们将对通过一系列的推导来向大家引入可以增强性能的自注意力机制\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 自注意力机制怎么增强性能\n",
    "\n",
    "在此之前，nn.Embedding致力于将现有的编码转化为其对应的下一位的编码\n",
    "\n",
    "但是一个很重要的点是其忽略了现有的编码中彼此之间的联系，\n",
    "\n",
    "如果可以利用好这份联系，使得每个字的编码可以**相互通信**，是从能产生更好的性能呢"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 8, 2])"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 此时我们以一个真实的情况为例，通过随机生成一些数据来代表当前的真实情况\n",
    "torch.manual_seed(42)   # 设置固定的种子，使得结果可以及时的复现\n",
    "B,T,C = 4,8,2 # batch, time, channels\n",
    "x = torch.randn(B,T,C)\n",
    "x.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- 最简单的通信方式，第五个编码可以很简单地与收到前面四个编码的平均影响，虽然这种通信的方式听起来也十分地弱"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[1.9269, 1.4873]])\n",
      "这是上面的均值累加：\n",
      "tensor([1.9269, 1.4873])\n",
      "----------------------\n",
      "tensor([[ 1.9269,  1.4873],\n",
      "        [ 0.9007, -2.1055]])\n",
      "这是上面的均值累加：\n",
      "tensor([ 1.4138, -0.3091])\n",
      "----------------------\n",
      "tensor([[ 1.9269,  1.4873],\n",
      "        [ 0.9007, -2.1055],\n",
      "        [ 0.6784, -1.2345]])\n",
      "这是上面的均值累加：\n",
      "tensor([ 1.1687, -0.6176])\n",
      "----------------------\n",
      "tensor([[ 1.9269,  1.4873],\n",
      "        [ 0.9007, -2.1055],\n",
      "        [ 0.6784, -1.2345],\n",
      "        [-0.0431, -1.6047]])\n",
      "这是上面的均值累加：\n",
      "tensor([ 0.8657, -0.8644])\n",
      "----------------------\n",
      "tensor([[ 1.9269,  1.4873],\n",
      "        [ 0.9007, -2.1055],\n",
      "        [ 0.6784, -1.2345],\n",
      "        [-0.0431, -1.6047],\n",
      "        [-0.7521,  1.6487]])\n",
      "这是上面的均值累加：\n",
      "tensor([ 0.5422, -0.3617])\n",
      "----------------------\n",
      "tensor([[ 1.9269,  1.4873],\n",
      "        [ 0.9007, -2.1055],\n",
      "        [ 0.6784, -1.2345],\n",
      "        [-0.0431, -1.6047],\n",
      "        [-0.7521,  1.6487],\n",
      "        [-0.3925, -1.4036]])\n",
      "这是上面的均值累加：\n",
      "tensor([ 0.3864, -0.5354])\n",
      "----------------------\n",
      "tensor([[ 1.9269,  1.4873],\n",
      "        [ 0.9007, -2.1055],\n",
      "        [ 0.6784, -1.2345],\n",
      "        [-0.0431, -1.6047],\n",
      "        [-0.7521,  1.6487],\n",
      "        [-0.3925, -1.4036],\n",
      "        [-0.7279, -0.5594]])\n",
      "这是上面的均值累加：\n",
      "tensor([ 0.2272, -0.5388])\n",
      "----------------------\n",
      "tensor([[ 1.9269,  1.4873],\n",
      "        [ 0.9007, -2.1055],\n",
      "        [ 0.6784, -1.2345],\n",
      "        [-0.0431, -1.6047],\n",
      "        [-0.7521,  1.6487],\n",
      "        [-0.3925, -1.4036],\n",
      "        [-0.7279, -0.5594],\n",
      "        [-0.7688,  0.7624]])\n",
      "这是上面的均值累加：\n",
      "tensor([ 0.1027, -0.3762])\n",
      "----------------------\n",
      "这是上面的均值累加：\n",
      "----------------------\n",
      "这是上面的均值累加：\n",
      "----------------------\n",
      "这是上面的均值累加：\n",
      "----------------------\n",
      "这是上面的均值累加：\n",
      "----------------------\n",
      "这是上面的均值累加：\n",
      "----------------------\n",
      "这是上面的均值累加：\n",
      "----------------------\n",
      "这是上面的均值累加：\n",
      "----------------------\n",
      "这是上面的均值累加：\n",
      "----------------------\n",
      "这是上面的均值累加：\n",
      "----------------------\n",
      "这是上面的均值累加：\n",
      "----------------------\n",
      "这是上面的均值累加：\n",
      "----------------------\n",
      "这是上面的均值累加：\n",
      "----------------------\n",
      "这是上面的均值累加：\n",
      "----------------------\n",
      "这是上面的均值累加：\n",
      "----------------------\n",
      "这是上面的均值累加：\n",
      "----------------------\n",
      "这是上面的均值累加：\n",
      "----------------------\n",
      "这是上面的均值累加：\n",
      "----------------------\n",
      "这是上面的均值累加：\n",
      "----------------------\n",
      "这是上面的均值累加：\n",
      "----------------------\n",
      "这是上面的均值累加：\n",
      "----------------------\n",
      "这是上面的均值累加：\n",
      "----------------------\n",
      "这是上面的均值累加：\n",
      "----------------------\n",
      "这是上面的均值累加：\n",
      "----------------------\n",
      "这是上面的均值累加：\n",
      "----------------------\n"
     ]
    }
   ],
   "source": [
    "# 第一种方式：为了循环和聚合，我们使用torch.mean函数进行操作\n",
    "\n",
    "xbow = torch.zeros((B,T,C))\n",
    "\n",
    "for b in range(B):  # 遍历所有的batch\n",
    "    for t in range(T):\n",
    "        # 使用切片操作x[b,:t+1]来获取x在第b个批次中前t+1个时间步的所有元素，得到一个形状为(t,C)的张量xprev。\n",
    "        xprev = x[b,:t+1] # (t,C)\n",
    "        # 使用torch.mean(xprev, 0)来计算xprev在第一个维度（dim=0）上的平均值。这个操作会返回一个形状为(C,)的张量，它的每个元素是xprev在对应列上的元素的平均值。\n",
    "        xbow[b,t] = torch.mean(xprev, 0) \n",
    "        \n",
    "        # print(b) if b==0 else None\n",
    "        # print(t) if b==0 else None\n",
    "        print(xprev) if b==0 else None   # 前面的累加\n",
    "        print('这是上面的均值累加：')\n",
    "        print(xbow[b,t]) if b==0 else None\n",
    "        print(\"----------------------\")\n",
    "\n",
    "# 虽然这样是可以行的，但是这里的实现方式明显可以进一步改进"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "a=\n",
      "tensor([[1.0000, 0.0000, 0.0000],\n",
      "        [0.5000, 0.5000, 0.0000],\n",
      "        [0.3333, 0.3333, 0.3333]])\n",
      "--\n",
      "b=\n",
      "tensor([[0., 1.],\n",
      "        [3., 0.],\n",
      "        [1., 1.]])\n",
      "--\n",
      "c=\n",
      "tensor([[0.0000, 1.0000],\n",
      "        [1.5000, 0.5000],\n",
      "        [1.3333, 0.6667]])\n"
     ]
    }
   ],
   "source": [
    "# 方法2：使用矩阵乘法\n",
    "\n",
    "a = torch.tril(torch.ones(3, 3))   # 创建一个下三角的矩阵的函数\n",
    "# torch.sum 计算张量a在每一横行上的值，\n",
    "a = a / torch.sum(a, 1, keepdim=True)    # ！ 这一步相当于是在\n",
    "b = torch.randint(0,10,(3,2)).float()\n",
    "c = a @ b\n",
    "print('a=')\n",
    "print(a)\n",
    "print('--')\n",
    "print('b=')\n",
    "print(b)\n",
    "print('--')\n",
    "print('c=')\n",
    "print(c)\n",
    "\n",
    "#  在这个过程中，实现了字符根据权重得到最终的结果"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* 这个时候可以揭晓，我们所想要平均的一直是**权重**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "weight = torch.tril(torch.ones(T, T))\n",
    "weight = weight / weight.sum(1, keepdim=True)  # 直接是在第一个维度上进行加和\n",
    "weight\n",
    "# 而在这个例子中的b，其实是x\n",
    "xbow2 = weight @ x  # (B,T,T) @ (B,T,C)   ------>   (B,T,C)\n",
    "torch.allclose(xbow,xbow2)   # 这个是用于检测两个张量是否在一定的容忍度内是相等的"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "结果为`True`，所以说明这样几行就解决了上面这个循环要做的事情"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* 然而，这里还有一种更为巧妙的方式可以实现"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\n",
    "# 第三种方式：使用softmax\n",
    "trils = torch.tril(torch.ones(T,T))\n",
    "\n",
    "weight = torch.zeros((T,T))  # 构造一个全为0的向量\n",
    "weight = weight.masked_fill(trils == 0,float('-inf'))  # 使所有tril为0的位置都变为无穷大\n",
    "# 然后，我们选择在每行的维度上去使用sotfmax，\n",
    "weight = F.softmax(weight,dim=-1)\n",
    "\n",
    "xbow3 = weight @ x\n",
    "\n",
    "torch.allclose(xbow,xbow3)   # 这个是用于检测两个张量是否在一定的容忍度内是相等的\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 8, 16])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 第4种方式 : 自注意力方式\n",
    "torch.manual_seed(1337)\n",
    "B,T,C = 4,8,32 # batch, time, channels\n",
    "x = torch.randn(B,T,C)\n",
    "\n",
    "# 一个简单的单头注意力机制示范\n",
    "head_size = 16\n",
    "key = nn.Linear(C, head_size, bias=False)\n",
    "query = nn.Linear(C, head_size, bias=False)\n",
    "value = nn.Linear(C, head_size, bias=False)\n",
    "k = key(x)   # (B, T, 16)\n",
    "q = query(x) # (B, T, 16)\n",
    "wei =  q @ k.transpose(-2, -1) # (B, T, 16) @ (B, 16, T) ---> (B, T, T)\n",
    "\n",
    "tril = torch.tril(torch.ones(T, T))\n",
    "wei = wei.masked_fill(tril == 0, float('-inf'))\n",
    "wei = F.softmax(wei, dim=-1)\n",
    "\n",
    "v = value(x)\n",
    "out = wei @ v\n",
    "\n",
    "out.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[-1.5713e-01,  8.8009e-01,  1.6152e-01, -7.8239e-01, -1.4289e-01,\n",
       "           7.4676e-01,  1.0068e-01, -5.2395e-01, -8.8726e-01,  1.9068e-01,\n",
       "           1.7616e-01, -5.9426e-01, -4.8124e-01, -4.8598e-01,  2.8623e-01,\n",
       "           5.7099e-01],\n",
       "         [ 6.7643e-01, -5.4770e-01, -2.4780e-01,  3.1430e-01, -1.2799e-01,\n",
       "          -2.9521e-01, -4.2962e-01, -1.0891e-01, -4.9282e-02,  7.2679e-01,\n",
       "           7.1296e-01, -1.1639e-01,  3.2665e-01,  3.4315e-01, -7.0975e-02,\n",
       "           1.2716e+00],\n",
       "         [ 4.8227e-01, -1.0688e-01, -4.0555e-01,  1.7696e-01,  1.5811e-01,\n",
       "          -1.6967e-01,  1.6217e-02,  2.1509e-02, -2.4903e-01, -3.7725e-01,\n",
       "           2.7867e-01,  1.6295e-01, -2.8951e-01, -6.7610e-02, -1.4162e-01,\n",
       "           1.2194e+00],\n",
       "         [ 1.9708e-01,  2.8561e-01, -1.3028e-01, -2.6552e-01,  6.6781e-02,\n",
       "           1.9535e-01,  2.8073e-02, -2.4511e-01, -4.6466e-01,  6.9287e-02,\n",
       "           1.5284e-01, -2.0324e-01, -2.4789e-01, -1.6213e-01,  1.9474e-01,\n",
       "           7.6778e-01],\n",
       "         [ 2.5104e-01,  7.3457e-01,  5.9385e-01,  2.5159e-01,  2.6064e-01,\n",
       "           7.5820e-01,  5.5947e-01,  3.5387e-01, -5.9338e-01, -1.0807e+00,\n",
       "          -3.1110e-01, -2.7809e-01, -9.0541e-01,  1.3181e-01, -1.3818e-01,\n",
       "           6.3715e-01],\n",
       "         [ 3.4277e-01,  4.9605e-01,  4.7248e-01,  3.0277e-01,  1.8440e-01,\n",
       "           5.8144e-01,  3.8245e-01,  2.9521e-01, -4.8969e-01, -7.7051e-01,\n",
       "          -1.1721e-01, -2.5412e-01, -6.8921e-01,  1.9795e-01, -1.5135e-01,\n",
       "           7.6659e-01],\n",
       "         [ 1.8658e-01, -9.6351e-02, -1.4300e-01,  3.0587e-01,  8.3441e-02,\n",
       "          -6.8646e-03, -2.0472e-01, -1.5350e-01, -7.6250e-02,  3.2689e-01,\n",
       "           3.0896e-01,  7.6626e-02,  9.9243e-02,  1.6560e-01,  1.9745e-01,\n",
       "           7.6248e-01],\n",
       "         [ 1.3013e-01, -3.2832e-02, -4.9645e-01,  2.8652e-01,  2.7042e-01,\n",
       "          -2.6357e-01, -7.3756e-02,  3.7857e-01,  7.4580e-02,  3.3827e-02,\n",
       "           1.4695e-02,  3.1937e-01,  2.9926e-01, -1.6530e-01, -3.8630e-02,\n",
       "           3.3748e-01]],\n",
       "\n",
       "        [[-1.3254e+00,  1.1236e+00,  2.2927e-01, -2.9970e-01, -7.6267e-03,\n",
       "           7.9364e-01,  8.9581e-01,  3.9650e-01, -6.6613e-01, -2.1844e-01,\n",
       "          -1.3539e+00,  4.1245e-01,  9.6011e-01, -1.0805e+00, -3.9751e-01,\n",
       "          -4.4439e-01],\n",
       "         [-3.8338e-01, -1.9659e-01,  8.8455e-02,  1.8560e-01, -8.7010e-02,\n",
       "           1.3239e-01,  3.0841e-01, -2.4350e-01, -1.9396e-01, -1.7634e-02,\n",
       "           4.8439e-01,  5.4210e-01, -2.0407e-02, -4.2467e-01, -2.3463e-01,\n",
       "          -4.6465e-01],\n",
       "         [-1.1100e+00,  3.2334e-01,  4.7054e-01, -6.3595e-02,  2.5443e-01,\n",
       "           1.5352e-01,  2.5186e-01,  2.6286e-01,  2.7916e-01, -3.1662e-03,\n",
       "          -3.2881e-02,  4.8191e-01,  7.4431e-01, -1.9921e-01,  2.7134e-01,\n",
       "          -8.5871e-02],\n",
       "         [-9.7190e-01,  4.6124e-01,  4.2349e-01, -1.7230e-02,  1.5847e-01,\n",
       "           4.1175e-01,  4.0764e-01,  2.4982e-01, -5.0322e-02,  4.1514e-03,\n",
       "          -3.9853e-01,  4.3551e-01,  7.0285e-01, -4.3081e-01,  2.6684e-02,\n",
       "          -2.0169e-01],\n",
       "         [ 3.3586e-01, -8.5915e-02,  9.3660e-01,  7.7311e-01,  1.8037e-01,\n",
       "           8.2853e-01, -6.9183e-02,  2.8814e-01,  1.1734e-01,  6.8448e-01,\n",
       "          -5.8500e-02,  1.2726e-01,  2.9780e-01,  1.9324e-01,  1.5655e-01,\n",
       "          -9.3004e-03],\n",
       "         [ 1.6984e-01,  3.0993e-02,  8.1557e-01,  6.1679e-01,  1.0429e-01,\n",
       "           7.4573e-01,  2.3072e-02,  3.0572e-01,  5.8163e-02,  5.7122e-01,\n",
       "          -4.5275e-02,  1.5051e-01,  3.2901e-01,  5.6984e-02,  1.0311e-01,\n",
       "          -9.9174e-02],\n",
       "         [ 4.6496e-02,  1.5765e-01,  3.9760e-01,  1.7619e-01, -2.1168e-01,\n",
       "           2.3365e-01, -6.2083e-02,  2.1726e-01, -7.8725e-03,  4.5389e-01,\n",
       "           3.4349e-01, -5.5631e-02,  3.3726e-01, -3.7591e-01, -1.0140e-02,\n",
       "          -4.5806e-01],\n",
       "         [-5.3896e-01,  7.5555e-01,  3.3034e-01, -1.5849e-01, -2.6740e-01,\n",
       "           4.3495e-01,  3.7772e-01,  5.5794e-01, -1.8369e-01,  1.5938e-01,\n",
       "          -2.1042e-01,  5.5790e-02,  6.3184e-01, -6.4884e-01, -9.6084e-02,\n",
       "          -5.0751e-01]],\n",
       "\n",
       "        [[ 6.8925e-02,  1.2248e+00, -4.1194e-01, -1.7046e-01, -6.9224e-01,\n",
       "          -2.9201e-01,  1.2704e+00, -6.8596e-01,  4.3798e-01, -2.6366e-01,\n",
       "           1.1528e-01,  1.1676e+00, -7.2138e-01, -1.2308e+00,  8.3821e-01,\n",
       "          -5.5987e-01],\n",
       "         [-4.6375e-01,  6.3807e-01, -1.5842e-01, -1.3309e-01, -5.9402e-01,\n",
       "          -5.0374e-01,  2.3289e-01, -3.2126e-01,  4.5781e-01, -1.8590e-01,\n",
       "           1.9215e-01,  3.7566e-01, -3.5905e-01, -7.7262e-01,  3.5036e-01,\n",
       "           6.9694e-02],\n",
       "         [-6.4044e-01,  1.3831e-01, -6.1007e-02, -1.1112e-01, -4.5228e-01,\n",
       "          -6.2271e-01, -1.7030e-01, -2.4949e-01,  5.0670e-01, -9.6444e-02,\n",
       "           4.8315e-01,  9.4986e-02, -2.9810e-01, -3.6538e-01,  3.9458e-01,\n",
       "           4.1512e-01],\n",
       "         [-6.7193e-01,  1.2516e-01,  7.3386e-02, -1.3198e-01, -1.7880e-01,\n",
       "          -5.6740e-01, -6.8226e-01,  5.0844e-02,  3.3051e-01,  7.8242e-02,\n",
       "           6.8022e-02, -2.4041e-01, -6.6864e-02, -1.8411e-01, -5.3514e-02,\n",
       "           4.5113e-01],\n",
       "         [-1.4270e-02,  1.0195e+00, -3.4792e-01, -1.6421e-01, -5.5846e-01,\n",
       "          -3.2457e-01,  9.9404e-01, -5.6891e-01,  4.0097e-01, -1.8123e-01,\n",
       "           1.1856e-01,  9.8704e-01, -6.4057e-01, -1.0320e+00,  7.3320e-01,\n",
       "          -4.3167e-01],\n",
       "         [-6.3858e-01, -7.6533e-02, -3.6510e-01,  1.7782e-01, -6.5426e-02,\n",
       "          -3.5158e-01,  7.9591e-02,  1.7384e-01,  3.6676e-01, -4.2302e-02,\n",
       "           2.4923e-01,  4.8239e-01, -2.1295e-01, -2.9492e-01,  3.4749e-01,\n",
       "          -1.7111e-01],\n",
       "         [-2.2366e-01, -5.5317e-02, -1.8296e-01,  2.4258e-01,  2.5357e-01,\n",
       "          -1.6154e-01, -2.3908e-01,  3.3243e-01,  1.0304e-01,  2.6067e-01,\n",
       "          -5.0670e-02,  3.6947e-01, -4.9856e-02,  1.1197e-01,  1.1752e-01,\n",
       "          -2.5078e-01],\n",
       "         [-2.4821e-01,  1.4845e-01, -3.5033e-01,  1.7102e-01,  1.6613e-01,\n",
       "          -2.0643e-01,  8.6633e-02,  8.8414e-02,  2.1188e-01,  2.5805e-01,\n",
       "           5.5145e-02,  4.2668e-01, -2.0443e-01, -1.7372e-01,  3.8899e-01,\n",
       "           5.1725e-02]],\n",
       "\n",
       "        [[ 9.7183e-02,  5.7301e-02, -1.0468e-01, -4.6654e-02, -1.4006e-01,\n",
       "          -8.4126e-01, -1.3625e-01, -6.7465e-01, -2.1541e-01,  1.0993e+00,\n",
       "           2.3427e-01,  3.2605e-02, -1.8521e-01,  1.4780e-01, -6.1045e-01,\n",
       "           1.5391e+00],\n",
       "         [ 1.9305e-01, -2.1031e-01, -3.4658e-01,  2.0567e-01, -1.7798e-01,\n",
       "          -7.4604e-01, -6.4427e-01, -6.9183e-01, -2.0558e-01,  7.0413e-01,\n",
       "           2.3632e-01,  9.8797e-04, -1.7015e-01,  1.1203e-01, -7.1064e-01,\n",
       "           1.2431e+00],\n",
       "         [ 2.9114e-01, -4.8343e-01, -5.9254e-01,  4.6477e-01, -2.1832e-01,\n",
       "          -6.4460e-01, -1.1627e+00, -7.0993e-01, -1.9703e-01,  2.9262e-01,\n",
       "           2.3669e-01, -3.1050e-02, -1.5471e-01,  7.7153e-02, -8.1137e-01,\n",
       "           9.3578e-01],\n",
       "         [ 1.7549e-01, -3.4260e-02, -2.0523e-01,  2.7644e-02, -2.1312e-01,\n",
       "          -5.6022e-01, -3.5273e-01, -6.2722e-01, -3.0037e-01,  4.6061e-01,\n",
       "           1.5004e-01,  1.9040e-02, -1.4646e-01,  1.7220e-01, -6.2559e-01,\n",
       "           1.0722e+00],\n",
       "         [ 1.7354e-01, -1.7962e-01, -2.7874e-01, -1.0590e-01, -1.2952e-01,\n",
       "          -3.5086e-01, -5.5830e-01, -3.8638e-01, -2.9719e-01,  3.3368e-02,\n",
       "           1.7392e-01,  5.5898e-02, -7.2007e-02,  1.3182e-02, -6.6710e-01,\n",
       "           5.4229e-01],\n",
       "         [ 2.4678e-01, -4.7274e-01, -5.2827e-01,  3.1212e-01, -1.7528e-01,\n",
       "          -4.8636e-01, -1.1223e+00, -5.4196e-01, -2.0142e-01,  4.0103e-02,\n",
       "           2.2231e-01, -2.9381e-02, -9.4354e-02,  2.6374e-02, -7.8726e-01,\n",
       "           6.2836e-01],\n",
       "         [-3.9784e-01,  2.5915e-01,  5.0358e-01, -4.6864e-01, -2.2024e-02,\n",
       "          -3.2242e-01, -1.2578e-01,  1.0634e-01,  1.3618e-01,  1.7780e-01,\n",
       "           1.0391e-01, -6.2540e-01,  3.8904e-01,  3.3690e-01, -5.5140e-01,\n",
       "           5.2246e-01],\n",
       "         [-3.5927e-01,  3.3935e-02, -2.9863e-02, -1.5019e-01, -6.0354e-03,\n",
       "          -6.5733e-02, -3.9659e-01, -6.0435e-02, -5.7551e-01, -2.9157e-01,\n",
       "           1.4899e-01, -7.5002e-02,  7.3228e-02, -4.7413e-02, -6.4394e-01,\n",
       "           2.8560e-01]]], grad_fn=<UnsafeViewBackward0>)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "out[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
